
[Cute,Fwd,Sm100] fp8 e4m3 and e5m2 support#2109

Open
dcw02 wants to merge 13 commits into Dao-AILab:main from modal-labs:fa4_fp8_support

Conversation

@dcw02 commented Dec 29, 2025

Summary

Adds FP8 (E4M3/E5M2) support to Flash Attention 4 on SM100 (Blackwell GPUs). This brings FP8 inference acceleration to the CuTe-DSL implementation with full correctness validation.

What's New

FP8 Forward Pass

  • Supported types: FP8 E4M3fn and E5M2 for Q/K/V inputs
  • Output format: BF16
  • Descaling: Optional per-(batch, kv_head) descale tensors for Q/K/V
  • Hardware: SM100 only
  • Note: Forward-only
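As context for the descale tensors, here is a minimal sketch of how a per-(batch, kv_head) descale entry is typically derived during quantization (the helper and its name are ours, not this PR's code; only the E4M3fn max value is a hardware fact):

```python
E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3fn

def per_head_scale(amax: float) -> tuple[float, float]:
    # `scale` is multiplied in before casting a head's values to FP8 so
    # the largest magnitude lands at E4M3_MAX; `descale` (its inverse)
    # is what the kernel consumes as the per-(batch, kv_head) entry.
    scale = E4M3_MAX / amax
    return scale, 1.0 / scale
```

For E5M2 the analogous max finite value is 57344.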

Key Implementation Details

Numerical stability fixes:

  • Added max-offset scaling (+8.0 offset, 256x row_sum compensation) to prevent underflow in FP8's limited dynamic range
  • Descales are absorbed into softmax scaling: softmax_scale_eff = softmax_scale * (q_descale * k_descale)
  • V descale applied during final output normalization
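To make those two bullets concrete, here is a toy base-2 softmax row with the offset trick (illustrative only; the kernel works on register tiles, not Python lists):

```python
def effective_softmax_scale(softmax_scale, q_descale=1.0, k_descale=1.0):
    # Q/K descales fold into one scalar, so S = Q @ K^T never needs an
    # element-wise rescale inside the kernel.
    return softmax_scale * (q_descale * k_descale)

def offset_softmax_row(scores, offset=8.0):
    # The exponentials are computed in base 2; adding +8.0 to the
    # argument multiplies every term by 2**8 = 256, lifting small
    # probabilities out of FP8's underflow region.  Dividing the same
    # factor back out of row_sum leaves the normalized result unchanged.
    row_max = max(scores)
    exps = [2.0 ** (s - row_max + offset) for s in scores]
    row_sum = sum(exps) / 256.0  # compensate the 2**offset inflation
    return [e / 256.0 / row_sum for e in exps]
```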

PTX updates:

  • Generalized MMA instruction kind selection (was hardcoded to f16, now dynamic: f16, f8f6f4, etc.)
  • All tcgen05.mma inline assembly uses correct FP8 variants
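A rough picture of the dynamic kind selection (the mapping below is illustrative only; the real code dispatches on CuTe-DSL dtypes, and bf16 sharing the f16 kind is our assumption):

```python
# Illustrative dtype-name -> tcgen05.mma instruction-kind mapping.
MMA_KIND_BY_DTYPE = {
    "float16": "f16",
    "bfloat16": "f16",        # assumed: bf16 shares the f16 MMA kind
    "float8_e4m3fn": "f8f6f4",
    "float8_e5m2": "f8f6f4",
}

def mma_kind(dtype_name: str) -> str:
    try:
        return MMA_KIND_BY_DTYPE[dtype_name]
    except KeyError as exc:
        raise ValueError(f"no tcgen05.mma kind for {dtype_name}") from exc
```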

Interface:

_flash_attn_fwd(q, k, v,
    q_descale=None,  # (batch, num_heads_kv) float32
    k_descale=None,  # (batch, num_heads_kv) float32
    v_descale=None,  # (batch, num_heads_kv) float32
    ...
)

DLPack workaround: PyTorch 2.9.1 doesn't support FP8 via DLPack, so we export as uint8 and override the element type.

Benchmark & Testing

Added FP8 benchmark (modeled after FA3's) that tests:

  • Multiple batch sizes & sequence lengths (up to 16K)
  • Causal and non-causal attention
  • Head dimensions 64, 128

Correctness checking:

  • Default: FP8 vs BF16 baseline (atol/rtol=0.5, matching FA3)
  • Debug mode: --check-quantization-only to isolate kernel bugs from quantization error

Usage:

python -m flash_attn.cute.benchmark_flash_attention_fp8

Results

### causal=False, headdim=64, batch=32, seqlen=512 ###
Pytorch fwd: 75.40 TFLOPs/s, 0.911 ms
FA4-CuTe-BF16 fwd: 656.98 TFLOPs/s, 0.105 ms
FA4-CuTe-FP8 fwd: 661.32 TFLOPs/s, 0.104 ms
### causal=False, headdim=64, batch=16, seqlen=1024 ###
Pytorch fwd: 87.90 TFLOPs/s, 1.564 ms
FA4-CuTe-BF16 fwd: 774.60 TFLOPs/s, 0.177 ms
FA4-CuTe-FP8 fwd: 767.90 TFLOPs/s, 0.179 ms
### causal=False, headdim=64, batch=8, seqlen=2048 ###
Pytorch fwd: 57.53 TFLOPs/s, 4.778 ms
FA4-CuTe-BF16 fwd: 862.17 TFLOPs/s, 0.319 ms
FA4-CuTe-FP8 fwd: 836.91 TFLOPs/s, 0.328 ms
### causal=False, headdim=64, batch=4, seqlen=4096 ###
Pytorch fwd: 69.10 TFLOPs/s, 7.956 ms
FA4-CuTe-BF16 fwd: 913.79 TFLOPs/s, 0.602 ms
FA4-CuTe-FP8 fwd: 878.58 TFLOPs/s, 0.626 ms
### causal=False, headdim=64, batch=2, seqlen=8192 ###
Pytorch fwd: 61.91 TFLOPs/s, 17.759 ms
FA4-CuTe-BF16 fwd: 940.18 TFLOPs/s, 1.169 ms
FA4-CuTe-FP8 fwd: 901.25 TFLOPs/s, 1.220 ms
### causal=False, headdim=64, batch=1, seqlen=16384 ###
Pytorch fwd: 116.94 TFLOPs/s, 18.805 ms
FA4-CuTe-BF16 fwd: 952.24 TFLOPs/s, 2.309 ms
FA4-CuTe-FP8 fwd: 915.24 TFLOPs/s, 2.403 ms
### causal=True, headdim=64, batch=32, seqlen=512 ###
Pytorch fwd: 20.68 TFLOPs/s, 1.661 ms
FA4-CuTe-BF16 fwd: 274.34 TFLOPs/s, 0.125 ms
FA4-CuTe-FP8 fwd: 288.32 TFLOPs/s, 0.119 ms
### causal=True, headdim=64, batch=16, seqlen=1024 ###
Pytorch fwd: 22.58 TFLOPs/s, 3.044 ms
FA4-CuTe-BF16 fwd: 429.49 TFLOPs/s, 0.160 ms
FA4-CuTe-FP8 fwd: 444.15 TFLOPs/s, 0.155 ms
### causal=True, headdim=64, batch=8, seqlen=2048 ###
Pytorch fwd: 16.45 TFLOPs/s, 8.353 ms
FA4-CuTe-BF16 fwd: 599.76 TFLOPs/s, 0.229 ms
FA4-CuTe-FP8 fwd: 602.30 TFLOPs/s, 0.228 ms
### causal=True, headdim=64, batch=4, seqlen=4096 ###
Pytorch fwd: 19.27 TFLOPs/s, 14.262 ms
FA4-CuTe-BF16 fwd: 747.56 TFLOPs/s, 0.368 ms
FA4-CuTe-FP8 fwd: 732.18 TFLOPs/s, 0.375 ms
### causal=True, headdim=64, batch=2, seqlen=8192 ###
Pytorch fwd: 17.78 TFLOPs/s, 30.919 ms
FA4-CuTe-BF16 fwd: 861.48 TFLOPs/s, 0.638 ms
FA4-CuTe-FP8 fwd: 822.35 TFLOPs/s, 0.669 ms
### causal=True, headdim=64, batch=1, seqlen=16384 ###
Pytorch fwd: 24.23 TFLOPs/s, 45.376 ms
FA4-CuTe-BF16 fwd: 930.28 TFLOPs/s, 1.182 ms
FA4-CuTe-FP8 fwd: 879.39 TFLOPs/s, 1.250 ms
### causal=False, headdim=128, batch=32, seqlen=512 ###
Pytorch fwd: 106.53 TFLOPs/s, 0.645 ms
FA4-CuTe-BF16 fwd: 961.91 TFLOPs/s, 0.071 ms
FA4-CuTe-FP8 fwd: 1090.02 TFLOPs/s, 0.063 ms
### causal=False, headdim=128, batch=16, seqlen=1024 ###
Pytorch fwd: 140.98 TFLOPs/s, 0.975 ms
FA4-CuTe-BF16 fwd: 1191.05 TFLOPs/s, 0.115 ms
FA4-CuTe-FP8 fwd: 1359.13 TFLOPs/s, 0.101 ms
### causal=False, headdim=128, batch=8, seqlen=2048 ###
Pytorch fwd: 105.50 TFLOPs/s, 2.606 ms
FA4-CuTe-BF16 fwd: 1356.85 TFLOPs/s, 0.203 ms
FA4-CuTe-FP8 fwd: 1548.29 TFLOPs/s, 0.178 ms
### causal=False, headdim=128, batch=4, seqlen=4096 ###
Pytorch fwd: 130.21 TFLOPs/s, 4.222 ms
FA4-CuTe-BF16 fwd: 1455.45 TFLOPs/s, 0.378 ms
FA4-CuTe-FP8 fwd: 1684.52 TFLOPs/s, 0.326 ms
### causal=False, headdim=128, batch=2, seqlen=8192 ###
Pytorch fwd: 120.37 TFLOPs/s, 9.135 ms
FA4-CuTe-BF16 fwd: 1515.04 TFLOPs/s, 0.726 ms
FA4-CuTe-FP8 fwd: 1769.24 TFLOPs/s, 0.621 ms
### causal=False, headdim=128, batch=1, seqlen=16384 ###
Pytorch fwd: 229.58 TFLOPs/s, 9.578 ms
FA4-CuTe-BF16 fwd: 1568.55 TFLOPs/s, 1.402 ms
FA4-CuTe-FP8 fwd: 1817.81 TFLOPs/s, 1.210 ms
### causal=True, headdim=128, batch=32, seqlen=512 ###
Pytorch fwd: 33.38 TFLOPs/s, 1.029 ms
FA4-CuTe-BF16 fwd: 434.61 TFLOPs/s, 0.079 ms
FA4-CuTe-FP8 fwd: 450.51 TFLOPs/s, 0.076 ms
### causal=True, headdim=128, batch=16, seqlen=1024 ###
Pytorch fwd: 39.78 TFLOPs/s, 1.728 ms
FA4-CuTe-BF16 fwd: 691.08 TFLOPs/s, 0.099 ms
FA4-CuTe-FP8 fwd: 720.50 TFLOPs/s, 0.095 ms
### causal=True, headdim=128, batch=8, seqlen=2048 ###
Pytorch fwd: 31.12 TFLOPs/s, 4.416 ms
FA4-CuTe-BF16 fwd: 987.49 TFLOPs/s, 0.139 ms
FA4-CuTe-FP8 fwd: 1018.44 TFLOPs/s, 0.135 ms
### causal=True, headdim=128, batch=4, seqlen=4096 ###
Pytorch fwd: 37.10 TFLOPs/s, 7.410 ms
FA4-CuTe-BF16 fwd: 1220.21 TFLOPs/s, 0.225 ms
FA4-CuTe-FP8 fwd: 1288.03 TFLOPs/s, 0.213 ms
### causal=True, headdim=128, batch=2, seqlen=8192 ###
Pytorch fwd: 34.67 TFLOPs/s, 15.856 ms
FA4-CuTe-BF16 fwd: 1388.77 TFLOPs/s, 0.396 ms
FA4-CuTe-FP8 fwd: 1524.63 TFLOPs/s, 0.361 ms
### causal=True, headdim=128, batch=1, seqlen=16384 ###
Pytorch fwd: 47.01 TFLOPs/s, 23.391 ms
FA4-CuTe-BF16 fwd: 1488.27 TFLOPs/s, 0.739 ms
FA4-CuTe-FP8 fwd: 1665.30 TFLOPs/s, 0.660 ms

@johnnynunez (Contributor)

@dcw02 can you resolve the conflicts?

@drisspg @tridao can you take a look?

@dcw02 (Author) commented Jan 11, 2026

yes, fixed merge conflicts

@Edenzzzz

It seems curious to me that FP8 can only get 0.95x - 1.15x speedup given it reduces data movement by half. Is it possible to get some profiling results to see if it's being bottlenecked by SFU/softmax? Thanks.

@dcw02 (Author) commented Jan 30, 2026

> It seems curious to me that FP8 can only get 0.95x - 1.15x speedup given it reduces data movement by half. Is it possible to get some profiling results to see if it's being bottlenecked by SFU/softmax? Thanks.

e2e softmax is turned off for fp8 since it hurt performance during my testing. Here are numbers with e2e turned off for bf16 as well:

### causal=False, headdim=64, batch=32, seqlen=512 ###
Pytorch fwd: 74.90 TFLOPs/s, 0.918 ms
FA4-CuTe-BF16 fwd: 645.28 TFLOPs/s, 0.106 ms
FA4-CuTe-FP8 fwd: 664.91 TFLOPs/s, 0.103 ms
### causal=False, headdim=64, batch=16, seqlen=1024 ###
Pytorch fwd: 87.53 TFLOPs/s, 1.570 ms
FA4-CuTe-BF16 fwd: 752.96 TFLOPs/s, 0.183 ms
FA4-CuTe-FP8 fwd: 769.66 TFLOPs/s, 0.179 ms
### causal=False, headdim=64, batch=8, seqlen=2048 ###
Pytorch fwd: 56.62 TFLOPs/s, 4.855 ms
FA4-CuTe-BF16 fwd: 824.65 TFLOPs/s, 0.333 ms
FA4-CuTe-FP8 fwd: 839.92 TFLOPs/s, 0.327 ms
### causal=False, headdim=64, batch=4, seqlen=4096 ###
Pytorch fwd: 69.27 TFLOPs/s, 7.937 ms
FA4-CuTe-BF16 fwd: 867.55 TFLOPs/s, 0.634 ms
FA4-CuTe-FP8 fwd: 879.04 TFLOPs/s, 0.625 ms
### causal=False, headdim=64, batch=2, seqlen=8192 ###
Pytorch fwd: 61.79 TFLOPs/s, 17.794 ms
FA4-CuTe-BF16 fwd: 888.56 TFLOPs/s, 1.237 ms
FA4-CuTe-FP8 fwd: 903.13 TFLOPs/s, 1.217 ms
### causal=False, headdim=64, batch=1, seqlen=16384 ###
Pytorch fwd: 117.46 TFLOPs/s, 18.721 ms
FA4-CuTe-BF16 fwd: 898.81 TFLOPs/s, 2.447 ms
FA4-CuTe-FP8 fwd: 916.96 TFLOPs/s, 2.398 ms
### causal=True, headdim=64, batch=32, seqlen=512 ###
Pytorch fwd: 20.41 TFLOPs/s, 1.683 ms
FA4-CuTe-BF16 fwd: 270.92 TFLOPs/s, 0.127 ms
FA4-CuTe-FP8 fwd: 287.03 TFLOPs/s, 0.120 ms
### causal=True, headdim=64, batch=16, seqlen=1024 ###
Pytorch fwd: 22.26 TFLOPs/s, 3.087 ms
FA4-CuTe-BF16 fwd: 424.22 TFLOPs/s, 0.162 ms
FA4-CuTe-FP8 fwd: 441.36 TFLOPs/s, 0.156 ms
### causal=True, headdim=64, batch=8, seqlen=2048 ###
Pytorch fwd: 16.20 TFLOPs/s, 8.482 ms
FA4-CuTe-BF16 fwd: 581.20 TFLOPs/s, 0.236 ms
FA4-CuTe-FP8 fwd: 603.44 TFLOPs/s, 0.228 ms
### causal=True, headdim=64, batch=4, seqlen=4096 ###
Pytorch fwd: 19.12 TFLOPs/s, 14.375 ms
FA4-CuTe-BF16 fwd: 713.45 TFLOPs/s, 0.385 ms
FA4-CuTe-FP8 fwd: 732.41 TFLOPs/s, 0.375 ms
### causal=True, headdim=64, batch=2, seqlen=8192 ###
Pytorch fwd: 17.62 TFLOPs/s, 31.205 ms
FA4-CuTe-BF16 fwd: 810.10 TFLOPs/s, 0.679 ms
FA4-CuTe-FP8 fwd: 821.98 TFLOPs/s, 0.669 ms
### causal=True, headdim=64, batch=1, seqlen=16384 ###
Pytorch fwd: 23.99 TFLOPs/s, 45.830 ms
FA4-CuTe-BF16 fwd: 863.02 TFLOPs/s, 1.274 ms
FA4-CuTe-FP8 fwd: 879.51 TFLOPs/s, 1.250 ms
### causal=False, headdim=128, batch=32, seqlen=512 ###
Pytorch fwd: 105.66 TFLOPs/s, 0.650 ms
FA4-CuTe-BF16 fwd: 927.44 TFLOPs/s, 0.074 ms
FA4-CuTe-FP8 fwd: 1086.92 TFLOPs/s, 0.063 ms
### causal=False, headdim=128, batch=16, seqlen=1024 ###
Pytorch fwd: 140.28 TFLOPs/s, 0.980 ms
FA4-CuTe-BF16 fwd: 1148.11 TFLOPs/s, 0.120 ms
FA4-CuTe-FP8 fwd: 1371.55 TFLOPs/s, 0.100 ms
### causal=False, headdim=128, batch=8, seqlen=2048 ###
Pytorch fwd: 104.08 TFLOPs/s, 2.641 ms
FA4-CuTe-BF16 fwd: 1277.05 TFLOPs/s, 0.215 ms
FA4-CuTe-FP8 fwd: 1559.74 TFLOPs/s, 0.176 ms
### causal=False, headdim=128, batch=4, seqlen=4096 ###
Pytorch fwd: 130.45 TFLOPs/s, 4.214 ms
FA4-CuTe-BF16 fwd: 1362.90 TFLOPs/s, 0.403 ms
FA4-CuTe-FP8 fwd: 1693.36 TFLOPs/s, 0.325 ms
### causal=False, headdim=128, batch=2, seqlen=8192 ###
Pytorch fwd: 120.05 TFLOPs/s, 9.158 ms
FA4-CuTe-BF16 fwd: 1387.40 TFLOPs/s, 0.793 ms
FA4-CuTe-FP8 fwd: 1777.17 TFLOPs/s, 0.619 ms
### causal=False, headdim=128, batch=1, seqlen=16384 ###
Pytorch fwd: 230.78 TFLOPs/s, 9.528 ms
FA4-CuTe-BF16 fwd: 1443.98 TFLOPs/s, 1.523 ms
FA4-CuTe-FP8 fwd: 1823.37 TFLOPs/s, 1.206 ms
### causal=True, headdim=128, batch=32, seqlen=512 ###
Pytorch fwd: 32.99 TFLOPs/s, 1.042 ms
FA4-CuTe-BF16 fwd: 439.63 TFLOPs/s, 0.078 ms
FA4-CuTe-FP8 fwd: 450.71 TFLOPs/s, 0.076 ms
### causal=True, headdim=128, batch=16, seqlen=1024 ###
Pytorch fwd: 39.26 TFLOPs/s, 1.751 ms
FA4-CuTe-BF16 fwd: 696.69 TFLOPs/s, 0.099 ms
FA4-CuTe-FP8 fwd: 721.17 TFLOPs/s, 0.095 ms
### causal=True, headdim=128, batch=8, seqlen=2048 ###
Pytorch fwd: 30.68 TFLOPs/s, 4.480 ms
FA4-CuTe-BF16 fwd: 974.73 TFLOPs/s, 0.141 ms
FA4-CuTe-FP8 fwd: 1017.89 TFLOPs/s, 0.135 ms
### causal=True, headdim=128, batch=4, seqlen=4096 ###
Pytorch fwd: 36.82 TFLOPs/s, 7.466 ms
FA4-CuTe-BF16 fwd: 1189.51 TFLOPs/s, 0.231 ms
FA4-CuTe-FP8 fwd: 1297.18 TFLOPs/s, 0.212 ms
### causal=True, headdim=128, batch=2, seqlen=8192 ###
Pytorch fwd: 34.35 TFLOPs/s, 16.003 ms
FA4-CuTe-BF16 fwd: 1343.47 TFLOPs/s, 0.409 ms
FA4-CuTe-FP8 fwd: 1527.45 TFLOPs/s, 0.360 ms
### causal=True, headdim=128, batch=1, seqlen=16384 ###
Pytorch fwd: 46.56 TFLOPs/s, 23.616 ms
FA4-CuTe-BF16 fwd: 1428.15 TFLOPs/s, 0.770 ms
FA4-CuTe-FP8 fwd: 1666.91 TFLOPs/s, 0.660 ms

mostly looking to just get the initial support in first, then we can optimize more afterwards

@zhc-hpc commented Feb 11, 2026

[screenshot of error attached] Hi, which version of nvidia-cutlass-dsl are you using? Why does this happen when I run it on my side?

@dcw02 (Author) commented Feb 12, 2026

> Hi, which version of nvidia-cutlass-dsl are you using? Why does this happen when I run it on my side?

On my branch: nvidia-cutlass-dsl==4.3.5, after uv pip install flash_attn/cute

@tridao (Member) commented Feb 13, 2026

I did a brief review and it seems fine. There's more to optimize for fp8 but that's for later.
Can you fix the lint issue and rebase?

window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
learnable_sink: Optional[cute.Tensor] = None,
mQDescale: Optional[cute.Tensor] = None,
Collaborator:
can we put the scales into a small struct?

Member:
does putting things in a struct make it more difficult to use tvm-ffi?
https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html#working-with-named-tuples

You know the tvm-ffi stuff better than i do

Collaborator:
yup, named tuple is the way:

class BlockSparseTensors(NamedTuple):

Author:
put the scales into a named tuple here and updated the wiring
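For readers following along, the grouping likely looks something like this (field names are our guess, not necessarily the PR's; `object` stands in for cute.Tensor):

```python
from typing import NamedTuple, Optional

class QKVDescales(NamedTuple):
    # Hypothetical container for the three optional per-(batch, kv_head)
    # descale tensors, mirroring the BlockSparseTensors pattern above.
    q_descale: Optional[object] = None
    k_descale: Optional[object] = None
    v_descale: Optional[object] = None
```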

acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum
stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0)
scale = scale * v_descale
@drisspg (Collaborator) commented Feb 13, 2026:
I doubt it has much of an impact, but did we measure that the non-QKV-scaled path doesn't take any performance hit from these changes?

Author:
There was no measurable performance impact, but I added compile-time gating.

@dcw02 (Author) commented Feb 14, 2026

Fixed lint and rebased, but I think torch==2.10.0 broke the PyTorch baselines in the benchmarking (this extends to the Hopper FP8 benchmark); I'm getting this error:

Traceback (most recent call last):
  File "/home/modal/flash-attention/flash_attn/cute/benchmark_flash_attention_fp8.py", line 434, in <module>
    raise SystemExit(main())
                     ^^^^^^
  File "/home/modal/flash-attention/flash_attn/cute/benchmark_flash_attention_fp8.py", line 297, in main
    out_ref_bf16 = attention_pytorch(qkv_bf16, causal=causal)  # warmup / reference
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modal/flash-attention/flash_attn/cute/benchmark_flash_attention_fp8.py", line 63, in attention_pytorch
    torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), "(b h) t s -> b h t s", h=nheads
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasGemmStridedBatchedEx(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, (int)ldb, strideb, (void*)&fbeta, c, std::is_same_v<C_Dtype, float> ? CUDA_R_32F : CUDA_R_16BF, (int)ldc, stridec, (int)num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

The benchmarking script still works if you downgrade to torch==2.9.1.

@drisspg (Collaborator) commented Mar 10, 2026

Can you do one more rebase please :)

@howardzhang-cv

Tested this FP8 low-precision attention implementation by adding it as a quantized overload through PyTorch (pytorch/pytorch#175472), then exercised it through the low-precision attention API in TorchAO, and saw some decent runtime results. See the PRs for more implementation details (pytorch/ao#3960 and pytorch/ao#3947).

Results

Single-Layer Results

Results directly comparing bf16 FA4 SDPA versus fp8 FA4 SDPA (including quantization time):
[image: single-layer benchmark results]

Llama3 Model Results

Results comparing Llama3 model with bf16 FA4 SDPA versus Llama3 using the fp8 FA4 SDPA.
Perplexity: 6.19 -> 6.24 (on WikiText2 dataset)
[image: Llama3 benchmark results]
